from agent_personas import *
import os
import pickle
import re
import argparse
from model_definitions import llama_8b_model, mistral_model
import glob
import baseline_prompts
from baseline_prompts import yaml_cqg, yaml_student_answer
from tqdm import tqdm

def parse_lines_yaml(yaml, tag='state_attribute'):
    """Extract the values of numbered ``<tag>_<n>:`` entries from YAML-like text.

    Args:
        yaml: Raw text (e.g. a model completion) to scan.
        tag: Entry prefix; it is interpolated into the regex unescaped, so it
            may itself be a regex fragment.

    Returns:
        A list with the rest-of-line text following each ``<tag>_<digits>:``
        key (matched case-insensitively), in order of appearance. The list is
        empty when nothing matches.
    """
    try:
        return re.findall(fr'{tag}_\d*:\s*\n*(.+)', yaml, re.IGNORECASE)
    except re.error:
        # `tag` did not form a valid pattern; fall back to a plain line scan
        # that does not depend on the tag at all.
        pass

    values = []
    # Join "key:\nvalue" pairs onto one line before scanning line by line.
    for line in yaml.replace(":\n", ": ").split('\n'):
        if '_' not in line:
            continue
        separator = line.find(': ')
        if separator != -1:
            values.append(line[separator + 2:])
    return values

def parse_yes_no(yaml, x):
    """Parse an ``explanation:`` entry and a boolean ``<x>:`` entry from text.

    Args:
        yaml: Raw YAML-like text to scan.
        x: Name of the key holding the verdict; interpolated into the regex
            unescaped.

    Returns:
        ``(explanation, answer)`` where ``explanation`` is the rest of the
        line after the first ``explanation:`` key and ``answer`` is True only
        when the first ``<x>:`` value is exactly ``True`` once surrounding
        whitespace is removed.

    Raises:
        IndexError: If either key is absent from ``yaml``.
    """
    explanation = re.findall(r'explanation:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    verdict = re.findall(fr'{x}:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    # `(.+)` captures trailing spaces and "\r"; strip them so "True " is not
    # silently read as False.
    answer = verdict.strip() == 'True'

    print(f'EXPLANATION: {explanation}')
    print(f'ANSWER: {answer}')
    return explanation, answer

def parse_bg_iso(message, bug_fixes):
    """Decide whether a verifier reply confirms every expected bug fix.

    Args:
        message: Verifier output containing ``correct_bug_fix_<n>_present:``
            entries.
        bug_fixes: The ground-truth fixes the verifier was asked about, either
            as a ``---\\nbug_fixes:\\n...\\n---\\n`` string (one fix per line)
            or as a sequence of fixes.

    Returns:
        True only if the reply has exactly one verdict per expected fix and
        none of the verdicts contains ``False``.
    """
    fix_results = re.findall(r'correct_bug_fix_\d+_present:\s*\n*(.+)', message, re.IGNORECASE)

    # Count the fixes from the argument rather than from a module global, so
    # the expected count always matches the fixes actually passed in.
    if isinstance(bug_fixes, str):
        wrapped = re.findall(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', bug_fixes, re.IGNORECASE)
        fix_lines = wrapped[0].split("\n") if wrapped else bug_fixes.strip().split("\n")
        expected_count = len(fix_lines)
    else:
        expected_count = len(bug_fixes)

    all_present = all("False" not in result for result in fix_results)
    return all_present and (len(fix_results) == expected_count)

def log(text):
    """Append ``text`` to the current run's ``log.txt``, framed by x-rulers.

    The file lives at ``LOG_FOLDER/FILE_NAME/log.txt``; both names are read
    from module globals at call time.
    """
    log_path = os.path.join(LOG_FOLDER, FILE_NAME, "log.txt")
    ruler_before = '\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n'
    ruler_after = '\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n'
    with open(log_path, 'a+') as handle:
        # Written piece by piece, in order: ruler, payload, ruler.
        for piece in (ruler_before, text, ruler_after):
            handle.write(piece)

class Instructor():
    """Instructor agent that asks questions and verifies the student's fixes.

    Reads the module globals ``instructor_persona``, ``problem_statement``,
    ``buggy_code``, ``bug_fixes``, ``bug_description``, ``num_questions`` and
    ``convo_history``. ``llama_8b_model`` is called like a text-generation
    pipeline exposing a ``.tokenizer`` (presumably a Hugging Face pipeline;
    confirm in ``model_definitions``).
    """

    def __init__(self):
        # System prompt shared by every question-generation request.
        self.init_prompt = instructor_persona + problem_statement + buggy_code
        self.model = llama_8b_model

    def _generate(self, messages, **generation_kwargs):
        """Render ``messages`` with the chat template and run the model.

        Args:
            messages: List of ``{"role", "content"}`` chat messages.
            **generation_kwargs: Extra generation options (e.g. ``do_sample``,
                ``temperature``) forwarded to the model call.

        Returns:
            ``(model_prompt, completion)``: the rendered prompt string and the
            generated text with that prompt prefix sliced off.
        """
        tokenizer = self.model.tokenizer
        model_prompt = tokenizer.apply_chat_template(messages,
                                                     tokenize=False,
                                                     add_generation_prompt=True)

        # Stop on either the tokenizer's EOS token or the "<|eot_id|>" token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = self.model(
            model_prompt,
            max_new_tokens=1024,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            **generation_kwargs
        )
        # The generated text echoes the prompt; keep only the new part.
        return model_prompt, outputs[0]["generated_text"][len(model_prompt):]

    def prompt_instructor(self, prompt):
        """Send ``prompt`` as the user turn (sampled, T=0.3, top_p=0.9).

        Logs the prompt and the reply, and returns the reply text.
        """
        messages = [
            {"role": "system", "content": self.init_prompt},
            {"role": "user", "content": prompt}]

        _, message = self._generate(messages,
                                    do_sample=True,
                                    temperature=0.3,
                                    top_p=0.9)

        print('prompted instructor')
        log('--- TO INSTRUCTOR: ' + prompt)
        log(f'--- FROM INSTRUCTOR: {message}')
        return message

    def generate_candidate_questions(self, level):
        """Ask the model for candidate questions and return them as a list.

        Level 0 uses the code-conditioned opening prompt; every later level
        uses the follow-up prompt built from the conversation so far. The
        result is empty when the reply contains no ``question_<n>:`` entry.
        """
        # conditional question generation
        if level == 0:
            prompt = baseline_prompts.i2i_cqg_code(1, bug_fixes, bug_description) + yaml_cqg
        else:
            prompt = baseline_prompts.i2i_qg(num_questions, '\n'.join(convo_history), bug_fixes, bug_description) + yaml_cqg

        reply = self.prompt_instructor(prompt)
        return parse_lines_yaml(reply, tag='question')

    def format_bug_fixes(self, fixes, prefix, counter):
        """Render ``fixes`` as ``<prefix>_<counter>: <fix>`` lines.

        ``counter`` is a single character (e.g. ``'1'`` or ``'a'``) that is
        advanced by one code point per fix. Returns the concatenated lines,
        each terminated by a newline, or ``""`` for no fixes.
        """
        lines = []
        for fix in fixes:
            lines.append(f'{prefix}_{counter}: {fix}\n')
            counter = chr(ord(counter) + 1)
        return "".join(lines)

    def check_bug_fixes(self, student_bug_fixes):
        """Ask the model whether the student's fixes cover the correct ones.

        Args:
            student_bug_fixes: List of fix descriptions proposed by the student.

        Returns:
            The boolean verdict produced by ``parse_bg_iso`` on the reply.

        Raises:
            IndexError: If the global ``bug_fixes`` is not in the
                ``---\\nbug_fixes:\\n...\\n---\\n`` format.
        """
        correct_bf = re.findall(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', bug_fixes, re.IGNORECASE)[0].split("\n")

        student_bf_formatted = self.format_bug_fixes(student_bug_fixes, 'suggested_bug_fix', 'a')
        correct_bf_formatted = self.format_bug_fixes(correct_bf, 'correct_bug_fix', '1')

        messages = [
            {"role": "system", "content": problem_statement},
            {"role": "user", "content": baseline_prompts.v2v_check_isomorphic_eq_bf(student_bf_formatted, correct_bf_formatted, len(correct_bf))}]

        # Greedy decoding, so the same prompt always takes the same decoding path.
        model_prompt, message = self._generate(messages, do_sample=False)

        print('prompted verifier state')
        log('--- TO INSTRUCTOR: ' + model_prompt)
        log(f'--- FROM INSTRUCTOR: {message}')

        return parse_bg_iso(message, bug_fixes)
    
class Student():
    """Student agent that answers questions and proposes bug fixes.

    Reads the module globals ``student_persona``, ``problem_statement`` and
    ``buggy_code``; ``mistral_model`` is called like a text-generation
    pipeline exposing a ``.tokenizer``.
    """

    def __init__(self):
        self.model = mistral_model
        self.init_prompt = student_persona + problem_statement + buggy_code

    def prompt_student(self, prompt, suffix="", do_sample=True):
        """Run the model on an ``[INST]``-wrapped prompt; return the full text.

        ``suffix`` is appended after ``[/INST]`` to seed the completion. With
        ``do_sample`` true, ``top_k=10`` is passed to the model as well. The
        returned string is the raw ``generated_text``; callers split on
        ``[/INST]`` to isolate the completion.
        """
        final_prompt = "".join(["<s>[INST]", self.init_prompt, prompt, "[/INST]", suffix])

        generation_kwargs = {
            "do_sample": do_sample,
            "num_return_sequences": 1,
            "max_new_tokens": 200,
            "pad_token_id": self.model.tokenizer.eos_token_id,
        }
        if do_sample:
            generation_kwargs["top_k"] = 10

        generated = self.model(final_prompt, **generation_kwargs)[0]['generated_text']

        print('prompted student')
        log('--- TO STUDENT: ' + prompt)
        log('--- FROM STUDENT: ' + generated.rpartition('[/INST]')[2])
        return generated

    def parse_student_answer(self, yaml):
        """Return the text after the last ``student_answer:`` key.

        Key/value pairs split across lines are joined first, then everything
        up to the last ``[/INST]`` marker is discarded.
        """
        flattened = yaml.replace(":\n", ": ").replace(": \n", ": ")
        completion = flattened.rpartition('[/INST]')[2]
        return completion.rpartition("student_answer:")[2]

    def ask_student(self, question):
        """Pose ``question`` to the student and return the parsed answer."""
        raw_reply = self.prompt_student(prompt=question + yaml_student_answer,
                                        suffix="\nstudent_answer: ")
        return self.parse_student_answer(raw_reply)

    def generate_bug_fixes(self, convo_history):
        """Ask for bug fixes given the conversation; return them as a list.

        Uses greedy decoding and collects the value of every ``bug_fix_<c>:``
        entry found in the completion.
        """
        transcript = '\n'.join(convo_history)
        raw_reply = self.prompt_student(baseline_prompts.i2s_generate_bug_fixes(transcript),
                                        do_sample=False)
        completion = raw_reply.rpartition('[/INST]')[2]
        return re.findall(r'bug_fix_.:\s*(.*)', completion, re.IGNORECASE)

def run():
    """Run one instructor/student dialogue for the currently loaded problem.

    Reads the module globals ``problem_statement``, ``buggy_code``,
    ``correct_code``, ``bug_fixes`` and ``convo_history`` (the latter is
    appended to in place). Each turn the instructor asks one question, the
    student answers and then proposes bug fixes, and the loop stops early
    once the instructor's check accepts those fixes.

    Returns:
        ``(student_bug_fixes, final_conv_history)``: the student's most recent
        list of proposed fixes (empty if no turn completed) and the
        conversation joined with ``========================`` separators.
    """
    log(f"problem statement:\n{problem_statement}\nbuggy_code:\n{buggy_code}\ncorrect_code:\n{correct_code}\nbug_fixes:\n{bug_fixes}")

    instructor = Instructor()
    student = Student()

    max_levels = 20

    # Defined up front so the return below is valid even if no turn completes.
    student_bug_fixes = []

    for turn in range(max_levels):

        # Instructor asks a question. Until the conversation has started, keep
        # using the opening (level 0) prompt, so a skipped first turn does not
        # fall through to the follow-up prompt with an empty history.
        level = turn if convo_history else 0
        candidate_questions = instructor.generate_candidate_questions(level)
        if not candidate_questions:
            # The reply had no parsable question; spend this turn and retry
            # instead of crashing on candidate_questions[0].
            log(f'--- NO QUESTION PARSED ON TURN {turn}; RETRYING')
            continue
        current_question = candidate_questions[0]

        print("Current Question = ", current_question)
        convo_history.append("TEACHER: " + current_question)

        # Student Response
        student_response = student.ask_student(current_question)
        convo_history.append("STUDENT: " + student_response)
        print("Student Reponse = ", student_response)

        # prompt bug fixes
        student_bug_fixes = student.generate_bug_fixes(convo_history)

        # Stop as soon as the instructor's check accepts the proposed fixes.
        if instructor.check_bug_fixes(student_bug_fixes):
            break

    final_conv_history = "\n========================\n".join(convo_history)

    return student_bug_fixes, final_conv_history



if __name__ == "__main__":
    # Driver: for every pickled problem under --file, run one dialogue and
    # write log.txt / bug_fixes.txt / convo.txt under <log_folder>/<problem>/.
    # Problems that already have a bug_fixes.txt are skipped, so an
    # interrupted batch can be resumed.

    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, default='data_pkls/')
    parser.add_argument('--bug_num', type=int, default=1)
    parser.add_argument('--log_folder', type=str, default='emnlp/single_bug_llama_vanilla')

    args = parser.parse_args()
    LOG_FOLDER = args.log_folder

    # makedirs (unlike mkdir) also creates missing parents, which the nested
    # default log folder needs; exist_ok makes re-runs a no-op.
    os.makedirs(LOG_FOLDER, exist_ok=True)

    files = glob.glob(f'{args.file}/*.pkl')
    for f in tqdm(files):

        # Problem name = file name without directory or extension.
        FILE_NAME = os.path.splitext(os.path.basename(f))[0]
        print("HELLO", FILE_NAME)

        run_dir = os.path.join(LOG_FOLDER, FILE_NAME)
        if os.path.exists(os.path.join(run_dir, 'bug_fixes.txt')):
            continue

        # Start from a clean slate: drop leftovers of an unfinished attempt.
        os.makedirs(run_dir, exist_ok=True)
        for stale_name in ('log.txt', 'bug_fixes.txt', 'convo.txt'):
            stale_path = os.path.join(run_dir, stale_name)
            if os.path.exists(stale_path):
                os.remove(stale_path)

        # NOTE(security): pickle executes arbitrary code on load; only point
        # --file at trusted data.
        with open(f, 'rb') as pkl_file:
            extracted_data = pickle.load(pkl_file)
        problem_statement = extracted_data['problem']
        buggy_code = extracted_data['buggy_code']

        bug_fixes = extracted_data['bug_fixes']
        all_bug_fixes = re.findall(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', bug_fixes, re.IGNORECASE)[0].split("\n")
        if args.bug_num == 1:
            # Keep only the first ground-truth fix.
            first_bug_fix = all_bug_fixes[0]
            single_fix_block = f'---\nbug_fixes:\n{first_bug_fix}\n---\n'
            # A callable replacement is inserted verbatim; a string one would
            # have its backslashes interpreted as regex escapes, which breaks
            # on fixes that mention e.g. "\n" or "\d".
            bug_fixes = re.sub(r'^---\nbug_fixes:\n[\S\s]*\n---\n$', lambda _match: single_fix_block, bug_fixes)

        bug_description = extracted_data['bug_desc'] # not a typo
        correct_code = extracted_data['correct_code']
        unit_tests = ''#extracted_data['unit_tests']

        convo_history = []
        num_questions = 1

        suggested_bug_fixes, final_conv_history = run()

        if suggested_bug_fixes:
            with open(os.path.join(run_dir, 'bug_fixes.txt'), 'w') as file:
                file.write("\n".join(suggested_bug_fixes))

        if final_conv_history:
            with open(os.path.join(run_dir, 'convo.txt'), 'w') as file:
                file.write(final_conv_history)
